whisper : fine-tuning grammar functionality #1
Conversation
examples/command/command.cpp
@@ -138,6 +139,9 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
    wparams.language  = params.language.c_str();
    wparams.n_threads = params.n_threads;

    // disable fallback - seems not useful for command recognition
    wparams.temperature_inc = 0.0f;
This seems to make the command recognition example much more robust, although we want to enable it from time to time to make sure that multi-decoder grammar usage is not broken.
It seems the reason this improves the results is that the fallback logic depends on the average logprob of the generated sequences.
Therefore, I think both the "best of" and "beam search" strategies would not work as expected.
We probably need some re-normalization of the logprobs after applying the grammar to make it compatible with these strategies.
In any case, I think for now we should focus on greedy sampling without fallbacks and improve later if needed.
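For illustration, such a re-normalization could be a log-softmax pass over the penalized logprobs, so they form a proper log-distribution again and the average-logprob fallback / beam scores stay comparable across decoders. A minimal sketch, not part of this PR:

#include <algorithm>
#include <cmath>
#include <vector>

static void renormalize_logprobs(std::vector<float> & logprobs) {
    // log-sum-exp with the usual max-subtraction for numerical stability
    float max_lp = -INFINITY;
    for (float lp : logprobs) {
        max_lp = std::max(max_lp, lp);
    }

    double sum = 0.0;
    for (float lp : logprobs) {
        sum += std::exp(lp - max_lp);
    }

    const float log_z = max_lp + (float) std::log(sum);
    for (float & lp : logprobs) {
        lp -= log_z; // now sum(exp(logprobs)) == 1 again
    }
}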
whisper.cpp
    for (const auto & reject : rejects) {
        if (logits[reject.id] > 0) {
            logits[reject.id] /= params.grammar_penalty;
        } else {
            logits[reject.id] *= params.grammar_penalty;
        }
        logprobs[reject.id] -= params.grammar_penalty;
    }
I think this makes more sense, though I'm still experimenting.
At least for the use case where we want to strongly restrict the output to match the grammar, this works well with a large penalty (>100.0).
Yeah, this seems to work well enough - I've been testing this code with penalties of 10-15 on the tiny model. I was hoping for a sweet spot that would guide the decoding towards valid strings while leaving obviously non-matching strings unchanged. But so far, I've found that the point that gives satisfying grammar matching still seems too eager to match unrelated words to the grammar.
I suppose both variants fail in some way for different use cases. This penalty is essentially a free parameter that is not obvious how to set. A more general way would be to let the user provide a penalty callback function and let them implement the penalization, but I guess for now let's keep it simple.
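For concreteness, a hypothetical sketch of what such a callback could look like - all names here are made up for illustration, and only whisper_grammar_candidate comes from this PR:

// hypothetical API shape, not part of whisper.h
typedef void (*whisper_grammar_penalty_callback)(
        const whisper_grammar_candidate * rejects,   // tokens rejected by the grammar
        int                               n_rejects,
        float                           * logits,    // raw logits, size n_vocab
        float                           * logprobs,  // log-probabilities, size n_vocab
        void                            * user_data);

// a default implementation mirroring the current fixed-penalty behaviour:
static void default_grammar_penalty(
        const whisper_grammar_candidate * rejects,
        int n_rejects, float * logits, float * logprobs, void * user_data) {
    const float penalty = *(const float *) user_data; // e.g. params.grammar_penalty

    for (int i = 0; i < n_rejects; ++i) {
        logprobs[rejects[i].id] -= penalty;
    }
    (void) logits; // a custom callback could scale the logits instead
}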
> But so far, I've found that the point that gives satisfying grammar matching still seems too eager to match unrelated words to the grammar.

Does the original scaling work better in this scenario?
> Does the original scaling work better in this scenario?

Anecdotally, no, not really.
whisper.cpp
    // when the grammar does not allow any continuation, we don't want to penalize the EOT token
    // TODO: is there are better way to do this?
    printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
    if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
        logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
    }
    //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
Seems we have an issue with EOT handling.
Here I try to penalize EOT if the grammar allows continuation, and this seems to help, although it's ugly.
There's actually the start of that, commented out, at the top of the function. I was just unsure initially if that penalty made sense in whisper land.
The decoded string is complete according to the grammar if some stack is empty, so the following will penalize EOT if the decoded string is an incomplete parse. The grammar might still allow continuation after a complete parse, so if we want to only penalize EOT if there's no continuation at all, I guess we'd switch the check to find a nonempty stack.
diff --git a/whisper.cpp b/whisper.cpp
index 5e3b86a..e796ddd 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -3872,13 +3872,13 @@ static void whisper_suppress_invalid_grammar(
return;
}
- // bool allow_eot = false;
- // for (const auto & stack : grammar.stacks) {
- // if (stack.empty()) {
- // allow_eot = true;
- // break;
- // }
- // }
+ bool allow_eot = false;
+ for (const auto & stack : grammar.stacks) {
+ if (stack.empty()) {
+ allow_eot = true;
+ break;
+ }
+ }
std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
std::vector<whisper_grammar_candidate> candidates_grammar;
@@ -3898,10 +3898,8 @@ static void whisper_suppress_invalid_grammar(
logprobs[reject.id] -= params.grammar_penalty;
}
- // when the grammar does not allow any continuation, we don't want to penalize the EOT token
- // TODO: is there are better way to do this?
- printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
- if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
+ // penalize EOT if grammar is incomplete
+ if (!allow_eot) {
logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
}
//fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
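For reference, the "nonempty stack" variant described above might look like this - a sketch, not applied in the diff, using the same variables as the surrounding function. EOT would then be penalized whenever the grammar still allows any continuation, even after a complete parse:

    bool allow_continuation = false;
    for (const auto & stack : grammar.stacks) {
        if (!stack.empty()) {
            allow_continuation = true;
            break;
        }
    }

    if (allow_continuation) {
        logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
    }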
whisper.cpp
    // suppress lang tokens
    for (size_t i = 0; i < g_lang.size(); ++i) {
        logits[whisper_token_lang(&ctx, i)] = -INFINITY;
    }

    // suppress prev token
    logits[vocab.token_prev] = -INFINITY;
These are unrelated to the grammar functionality - I think they improve the transcription in general.
Not sure why OpenAI doesn't have these filters in the original implementation (or maybe they do?).
whisper.cpp
        logprobs[i] = -INFINITY;
    }

    whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar);
Moving this down here improves the results in my tests significantly. The idea is that we keep the logic for sampling a timestamp token intact, irrespective of what the grammar suggests. Only after we have decided that the next token is text do we apply the grammar filters.
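Roughly, the resulting order in the logits processing becomes something like this - a simplified sketch, not the literal whisper.cpp code; the helper names mirror the timestamp heuristic variables, but treat them as illustrative:

    // 1) decide timestamp vs. text on the *unfiltered* distribution, so the
    //    timestamp heuristics are not disturbed by the grammar penalties
    if (timestamp_logprob > max_text_logprob) {
        // ... sample/force a timestamp token, grammar never consulted ...
    } else {
        // 2) only once the next token is known to be text, apply the grammar
        whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar);
    }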
Makes sense! This was my first foray into whisper internals so I don't have a nuanced understanding of the decoding process.
Looking pretty good! I've been testing this draft with the chess grammar from my demo and a more complex home assistant grammar (the latter courtesy of GPT).

Chess grammar (collapsed)

Home assistant grammar (collapsed)
- option to read grammar from file
- add sample grammars for colors and chess moves
- fine-tune the performance further
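For illustration, a minimal colors grammar in GBNF might look like this - a sketch; the actual grammars/colors.gbnf on the branch may differ:

# root matches a single color word, preceded by a space as Whisper tokenizes it
root  ::= " " color
color ::= "red" | "green" | "blue"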
I've been testing with the following 3 commands on this branch and I think the results are really good:

./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red green blue"

# say stuff like "d4 d5 knight to c3"
./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "pawn knight king a1 f5 h6"

./command -m ./models/ggml-base.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands"

The one thing that is missing is to somehow filter "false positives", i.e. when you say something completely outside the grammar, the system should recognize that it is not related.
Yeah, that was my struggle exactly. An early thought I had, which I haven't tried yet, was to abandon grammar sampling if all the tokens suggested by the grammar seem sufficiently unlikely. Perhaps some logit threshold that a token would have to satisfy, or top-k/top-p rank?
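One possible shape for that heuristic - an illustrative sketch, not implemented anywhere in this PR; the function name and the threshold are assumptions:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

typedef int32_t whisper_token; // as in whisper.h

// abandon grammar sampling when even the best grammar-allowed token is
// unlikely under the unfiltered distribution
static bool grammar_seems_irrelevant(
        const float                      * logprobs,   // unfiltered log-probabilities
        const std::vector<whisper_token> & allowed,    // tokens the grammar accepts
              float                        threshold)  // free parameter, e.g. logf(0.01f)
{
    float best = -INFINITY;
    for (const auto id : allowed) {
        best = std::max(best, logprobs[id]);
    }

    // if no allowed token clears the threshold, treat the utterance as
    // outside the grammar and skip the grammar penalties entirely
    return best < threshold;
}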
I did a few more updates to the implementation and, using the provided tests, I get satisfying results. It's not yet as "magical" as I imagined it could be (at least with […]).

After these changes, I am using the following commands: […]

Currently, if I say something completely wrong that does not match the grammar, it usually transcribes an empty string, which is good. The performance due to the beam search is a bit worse, but I think we will improve it from […]. Let me know if you have any concerns with these changes or if your tests no longer work as expected.
That all sounds good. I'll merge now and will be able to test this evening.
This is just a draft. If it works OK, we will merge it to ggerganov#1229.
I believe I have improved the grammar behaviour. See comments in the diff below.
At least this works quite robustly with the "red", "green", "blue" command example, both in English and Bulgarian.
I want to play with some more sophisticated experiments and also make sure I didn't break something else.
Let me know what you think about these changes - if you give this branch a try, you might want to disable the verbose prints in the #if 1 sections.